{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Understanding the vanishing gradient problem through visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There're reasons why deep neural network could work very well, while few people get a promising result or make it possible by simply make their neural network *deep*.\n",
    "\n",
    "* Computational power and data grow tremendously. People need more complex model and faster computer to make it feasible.\n",
    "* Realize and understand the difficulties associated with training a deep model.\n",
    "\n",
    "In this tutorial, we would like to show you some insights of the techniques that researchers find useful in training a deep model, using MXNet and its visualizing tool -- TensorBoard.\n",
    "\n",
    "Let’s recap some of the relevant issues on training a deep model:\n",
    "\n",
    "* Weight initialization.  If you initialize the network with random and small weights, when you look at the gradients down the top layer, you would find they’re getting smaller and smaller, then the first layer almost doesn’t change as the gradients are too small to make a significant update. Without a chance to learn the first layer effectively, it's impossible to update and learn a good deep model.\n",
    "* Nonlinearity activation. When people use `sigmoid` or `tanh` as activation function, the gradient, same as the above, is getting smaller and smaller. Just remind the formula of the parameter updates and the gradient."
   ]
  },
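  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the relationship between the parameter update and the gradient concrete, here is a rough sketch of plain gradient descent and backpropagation for a fully-connected layer $l$. The notation is introduced here for illustration only: $W^{(l)}$ and $b^{(l)}$ are the layer's weights and bias, $z^{(l)}$ its pre-activation, $a^{(l)}$ its output, $\\sigma$ the activation function, $\\eta$ the learning rate, $L$ the loss, and $\\delta^{(l)} = \\partial L / \\partial z^{(l)}$. Momentum and weight decay are ignored:\n",
    "\n",
    "$$z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}, \\qquad a^{(l)} = \\sigma\\left(z^{(l)}\\right)$$\n",
    "\n",
    "$$W^{(l)} \\leftarrow W^{(l)} - \\eta \\, \\frac{\\partial L}{\\partial W^{(l)}}, \\qquad \\frac{\\partial L}{\\partial W^{(l)}} = \\delta^{(l)} \\left(a^{(l-1)}\\right)^{\\top}$$\n",
    "\n",
    "$$\\delta^{(l)} = \\left(\\left(W^{(l+1)}\\right)^{\\top} \\delta^{(l+1)}\\right) \\odot \\sigma'\\left(z^{(l)}\\right)$$\n",
    "\n",
    "For the sigmoid, $\\sigma'(z) = \\sigma(z)\\left(1 - \\sigma(z)\\right) \\le \\tfrac{1}{4}$. Each step back through the network therefore multiplies the error signal by the weights and by a factor of at most $\\tfrac{1}{4}$, so with small initial weights $\\delta^{(l)}$ tends to shrink roughly geometrically with the number of layers between layer $l$ and the output."
   ]
  },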
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data and DataIter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def download(data_dir):\n",
    "    if not os.path.isdir(data_dir):\n",
    "        os.system('mkdir ' + data_dir)\n",
    "    os.chdir(data_dir)\n",
    "    if (not os.path.exists('train-images-idx3-ubyte')) or \\\n",
    "       (not os.path.exists('train-labels-idx1-ubyte')) or \\\n",
    "       (not os.path.exists('t10k-images-idx3-ubyte')) or \\\n",
    "       (not os.path.exists('t10k-labels-idx1-ubyte')):\n",
    "           os.system('wget http://data.mxnet.io/mxnet/data/mnist.zip')\n",
    "           os.system('unzip mnist.zip; rm mnist.zip')\n",
    "    os.chdir('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_iterator(data_shape):\n",
    "    def get_iterator_impl(args, kv):\n",
    "        data_dir = args.data_dir\n",
    "        # if Windows\n",
    "        if os.name == \"nt\":\n",
    "            data_dir = data_dir[:-1] + \"\\\\\"\n",
    "        if '://' not in args.data_dir:\n",
    "            download(data_dir)\n",
    "        flat = False if len(data_shape) == 3 else True\n",
    "\n",
    "        train           = mx.io.MNISTIter(\n",
    "            image       = data_dir + \"train-images-idx3-ubyte\",\n",
    "            label       = data_dir + \"train-labels-idx1-ubyte\",\n",
    "            input_shape = data_shape,\n",
    "            batch_size  = args.batch_size,\n",
    "            shuffle     = True,\n",
    "            flat        = flat,\n",
    "            num_parts   = kv.num_workers,\n",
    "            part_index  = kv.rank)\n",
    "\n",
    "        val = mx.io.MNISTIter(\n",
    "            image       = data_dir + \"t10k-images-idx3-ubyte\",\n",
    "            label       = data_dir + \"t10k-labels-idx1-ubyte\",\n",
    "            input_shape = data_shape,\n",
    "            batch_size  = args.batch_size,\n",
    "            flat        = flat,\n",
    "            num_parts   = kv.num_workers,\n",
    "            part_index  = kv.rank)\n",
    "\n",
    "        return (train, val)\n",
    "    return get_iterator_impl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Network Structure\n",
    "\n",
    "Here's the network structure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_mlp(acti=\"relu\"):\n",
    "    \"\"\"\n",
    "    multi-layer perceptron\n",
    "    \"\"\"\n",
    "    data = mx.symbol.Variable('data')\n",
    "    fc   = mx.symbol.FullyConnected(data = data, name='fc', num_hidden=512)\n",
    "    act  = mx.symbol.Activation(data = fc, name='act', act_type=acti)\n",
    "    fc0  = mx.symbol.FullyConnected(data = act, name='fc0', num_hidden=256)\n",
    "    act0 = mx.symbol.Activation(data = fc0, name='act0', act_type=acti)\n",
    "    fc1  = mx.symbol.FullyConnected(data = act0, name='fc1', num_hidden=128)\n",
    "    act1 = mx.symbol.Activation(data = fc1, name='act1', act_type=acti)\n",
    "    fc2  = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)\n",
    "    act2 = mx.symbol.Activation(data = fc2, name='act2', act_type=acti)\n",
    "    fc3  = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=32)\n",
    "    act3 = mx.symbol.Activation(data = fc3, name='act3', act_type=acti)\n",
    "    fc4  = mx.symbol.FullyConnected(data = act3, name='fc4', num_hidden=16)\n",
    "    act4 = mx.symbol.Activation(data = fc4, name='act4', act_type=acti)\n",
    "    fc5  = mx.symbol.FullyConnected(data = act4, name='fc5', num_hidden=10)\n",
    "    mlp  = mx.symbol.SoftmaxOutput(data = fc5, name = 'softmax')\n",
    "    return mlp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you might already notice, we intentionally add more layers than usual, as the vanished gradient problem becomes severer as the network goes deeper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment Setting\n",
    "\n",
    "Here we create a simple MLP for cifar10 dataset and visualize the learning processing through loss/accuracy, and its gradient distributions, by changing its initialization and activation setting.\n",
    "\n",
    "## General Setting \n",
    "\n",
    "We adopt MLP as our model and run our experiment in MNIST dataset. Then we'll visualize the weight and gradient of a layer using `Monitor` in MXNet and `Histogram` in TensorBoard.\n",
    "\n",
    "\n",
    "### Weight Initialization\n",
    "\n",
    "The weight initialization also has `uniform` and `xavier`. \n",
    "\n",
    "```python\n",
    "if args.init == 'uniform':\n",
    "        init = mx.init.Uniform(0.1)\n",
    "if args.init == 'xavier':\n",
    "    init = mx.init.Xavier(factor_type=\"in\", magnitude=2.34)\n",
    "```\n",
    "\n",
    "Note that we intentionally choose a near zero setting in `uniform`. \n",
    "\n",
    "### Activation  Function\n",
    "\n",
    "We would compare two different activations, `sigmoid` and `relu`. \n",
    "\n",
    "```python\n",
    "# acti = sigmoid or relu.\n",
    "act  = mx.symbol.Activation(data = fc, name='act', act_type=acti)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logging with TensorBoard and Monitor\n",
    "\n",
    "In order to monitor the weight and gradient of this network in different settings, we could use MXNet's `monitor` for logging and `TensorBoard` for visualization.\n",
    "\n",
    "### Usage\n",
    "\n",
    "Here's a code snippet from `train_model.py`:\n",
    "\n",
    "```python\n",
    "import mxnet as mx\n",
    "from tensorboard import summary\n",
    "from tensorboard import FileWriter\n",
    "\n",
    "# where to keep your TensorBoard logging file\n",
    "logdir = './logs/'\n",
    "summary_writer = FileWriter(logdir)\n",
    "\n",
    "# mx.mon.Monitor's callback \n",
    "def get_gradient(g):\n",
    "    # get flatten list\n",
    "    grad = g.asnumpy().flatten()\n",
    "    # logging using tensorboard, use histogram type.\n",
    "    s = summary.histogram('fc_backward_weight', grad)\n",
    "    summary_writer.add_summary(s)\n",
    "    return mx.nd.norm(g)/np.sqrt(g.size)\n",
    "\n",
    "mon = mx.mon.Monitor(int(args.num_examples/args.batch_size), get_gradient, pattern='fc_backward_weight')  # get the gradient passed to the first fully-connnected layer.\n",
    "\n",
    "# training\n",
    "model.fit(\n",
    "        X                  = train,\n",
    "        eval_data          = val,\n",
    "        eval_metric        = eval_metrics,\n",
    "        kvstore            = kv,\n",
    "        monitor            = mon,\n",
    "        epoch_end_callback = checkpoint)\n",
    "\n",
    "# close summary_writer\n",
    "summary_writer.close()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import argparse\n",
    "import os, sys\n",
    "\n",
    "def parse_args(init_type, name):\n",
    "    parser = argparse.ArgumentParser(description='train an image classifer on mnist')\n",
    "    parser.add_argument('--network', type=str, default='mlp',\n",
    "                        choices = ['mlp', 'lenet', 'lenet-stn'],\n",
    "                        help = 'the cnn to use')\n",
    "    parser.add_argument('--data-dir', type=str, default='mnist/',\n",
    "                        help='the input data directory')\n",
    "    parser.add_argument('--gpus', type=str,\n",
    "                        help='the gpus will be used, e.g \"0,1,2,3\"')\n",
    "    parser.add_argument('--num-examples', type=int, default=60000,\n",
    "                        help='the number of training examples')\n",
    "    parser.add_argument('--batch-size', type=int, default=128,\n",
    "                        help='the batch size')\n",
    "    parser.add_argument('--lr', type=float, default=.1,\n",
    "                        help='the initial learning rate')\n",
    "    parser.add_argument('--model-prefix', type=str,\n",
    "                        help='the prefix of the model to load/save')\n",
    "    parser.add_argument('--save-model-prefix', type=str,\n",
    "                        help='the prefix of the model to save')\n",
    "    parser.add_argument('--num-epochs', type=int, default=10,\n",
    "                        help='the number of training epochs')\n",
    "    parser.add_argument('--load-epoch', type=int,\n",
    "                        help=\"load the model on an epoch using the model-prefix\")\n",
    "    parser.add_argument('--kv-store', type=str, default='local',\n",
    "                        help='the kvstore type')\n",
    "    parser.add_argument('--lr-factor', type=float, default=1,\n",
    "                        help='times the lr with a factor for every lr-factor-epoch epoch')\n",
    "    parser.add_argument('--lr-factor-epoch', type=float, default=1,\n",
    "                        help='the number of epoch to factor the lr, could be .5')\n",
    "    parser.add_argument('--init', type=str, default=init_type,\n",
    "                        help='the weight initialization method')\n",
    "    parser.add_argument('--name', type=str, default=name,\n",
    "                        help='name for summary.histogram for gradient/weight logging')\n",
    "    return parser.parse_args(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import logging\n",
    "import os\n",
    "import numpy as np\n",
    "from tensorboard import summary\n",
    "from tensorboard import FileWriter\n",
    "\n",
    "def fit(args, network, data_loader, batch_end_callback=None):\n",
    "    # kvstore\n",
    "    kv = mx.kvstore.create(args.kv_store)\n",
    "\n",
    "    # logging\n",
    "    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'\n",
    "    if 'log_file' in args and args.log_file is not None:\n",
    "        log_file = args.log_file\n",
    "        log_dir = args.log_dir\n",
    "        log_file_full_name = os.path.join(log_dir, log_file)\n",
    "        if not os.path.exists(log_dir):\n",
    "            os.mkdir(log_dir)\n",
    "        logger = logging.getLogger()\n",
    "        handler = logging.FileHandler(log_file_full_name)\n",
    "        formatter = logging.Formatter(head)\n",
    "        handler.setFormatter(formatter)\n",
    "        logger.addHandler(handler)\n",
    "        logger.setLevel(logging.DEBUG)\n",
    "        logger.info('start with arguments %s', args)\n",
    "    else:\n",
    "        logging.basicConfig(level=logging.DEBUG, format=head)\n",
    "        logging.info('start with arguments %s', args)\n",
    "\n",
    "    # load model\n",
    "    model_prefix = args.model_prefix\n",
    "    if model_prefix is not None:\n",
    "        model_prefix += \"-%d\" % (kv.rank)\n",
    "    model_args = {}\n",
    "    if args.load_epoch is not None:\n",
    "        assert model_prefix is not None\n",
    "        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)\n",
    "        model_args = {'arg_params' : tmp.arg_params,\n",
    "                      'aux_params' : tmp.aux_params,\n",
    "                      'begin_epoch' : args.load_epoch}\n",
    "        # TODO: check epoch_size for 'dist_sync'\n",
    "        epoch_size = args.num_examples / args.batch_size\n",
    "        model_args['begin_num_update'] = epoch_size * args.load_epoch\n",
    "\n",
    "    # save model\n",
    "    save_model_prefix = args.save_model_prefix\n",
    "    if save_model_prefix is None:\n",
    "        save_model_prefix = model_prefix\n",
    "    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)\n",
    "\n",
    "    # data\n",
    "    (train, val) = data_loader(args, kv)\n",
    "\n",
    "    # train\n",
    "    devs = [mx.cpu(i) for i in range(4)] if args.gpus is None else [\n",
    "        mx.gpu(int(i)) for i in args.gpus.split(',')]\n",
    "\n",
    "    epoch_size = args.num_examples / args.batch_size\n",
    "\n",
    "    if args.kv_store == 'dist_sync':\n",
    "        epoch_size /= kv.num_workers\n",
    "        model_args['epoch_size'] = epoch_size\n",
    "\n",
    "    if 'lr_factor' in args and args.lr_factor < 1:\n",
    "        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(\n",
    "            step = max(int(epoch_size * args.lr_factor_epoch), 1),\n",
    "            factor = args.lr_factor)\n",
    "\n",
    "    if 'clip_gradient' in args and args.clip_gradient is not None:\n",
    "        model_args['clip_gradient'] = args.clip_gradient\n",
    "\n",
    "    # disable kvstore for single device\n",
    "    if 'local' in kv.type and (\n",
    "            args.gpus is None or len(args.gpus.split(',')) is 1):\n",
    "        kv = None\n",
    "    \n",
    "    if args.init == 'uniform':\n",
    "        init = mx.init.Uniform(0.1)\n",
    "    if args.init == 'normal':\n",
    "        init = mx.init.Normal(0,0.1)\n",
    "    if args.init == 'xavier':\n",
    "        init = mx.init.Xavier(factor_type=\"in\", magnitude=2.34)\n",
    "    model = mx.model.FeedForward(\n",
    "        ctx                = devs,\n",
    "        symbol             = network,\n",
    "        num_epoch          = args.num_epochs,\n",
    "        learning_rate      = args.lr,\n",
    "        momentum           = 0.9,\n",
    "        wd                 = 0.00001,\n",
    "        initializer        = init,\n",
    "        **model_args)\n",
    "\n",
    "    eval_metrics = ['accuracy']\n",
    "    ## TopKAccuracy only allows top_k > 1\n",
    "    for top_k in [5]:\n",
    "        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k = top_k))\n",
    "\n",
    "    if batch_end_callback is not None:\n",
    "        if not isinstance(batch_end_callback, list):\n",
    "            batch_end_callback = [batch_end_callback]\n",
    "    else:\n",
    "        batch_end_callback = []\n",
    "    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))\n",
    "    \n",
    "    logdir = './logs/'\n",
    "    summary_writer = FileWriter(logdir)\n",
    "    def get_grad(g):\n",
    "        # logging using tensorboard\n",
    "        grad = g.asnumpy().flatten()\n",
    "        s = summary.histogram(args.name, grad)\n",
    "        summary_writer.add_summary(s)\n",
    "        return mx.nd.norm(g)/np.sqrt(g.size)\n",
    "    mon = mx.mon.Monitor(int(args.num_examples/args.batch_size), get_grad, pattern='fc_backward_weight')  # get weight of first fully-connnected layer\n",
    "    \n",
    "    model.fit(\n",
    "        X                  = train,\n",
    "        eval_data          = val,\n",
    "        eval_metric        = eval_metrics,\n",
    "        kvstore            = kv,\n",
    "        monitor            = mon,\n",
    "        epoch_end_callback = checkpoint)\n",
    "\n",
    "    summary_writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What to expect?\n",
    "\n",
    "If a setting suffers from an vanish gradient problem, the gradients passed from the top should be very close to zero, and the weight of the network barely change/update. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uniform and Sigmoid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-01-12 19:30:11,856 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='uniform', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='uniform_sigmoid', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)\n",
      "2017-01-12 19:30:14,983 Node[0] \u001b[91m[Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n",
      "2017-01-12 19:30:14,990 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]\n",
      "2017-01-12 19:30:17,453 Node[0] Batch:       1 fc_backward_weight             5.1907e-07\t\n",
      "2017-01-12 19:30:17,454 Node[0] Batch:       1 fc_backward_weight             4.2085e-07\t\n",
      "2017-01-12 19:30:17,455 Node[0] Batch:       1 fc_backward_weight             4.31894e-07\t\n",
      "2017-01-12 19:30:17,456 Node[0] Batch:       1 fc_backward_weight             5.80652e-07\t\n",
      "2017-01-12 19:30:22,885 Node[0] Epoch[0] Resetting Data Iterator\n",
      "2017-01-12 19:30:22,888 Node[0] Epoch[0] Time cost=7.839\n",
      "2017-01-12 19:30:23,215 Node[0] Epoch[0] Validation-accuracy=0.105769\n",
      "2017-01-12 19:30:23,216 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.509115\n",
      "2017-01-12 19:30:25,728 Node[0] Batch:     469 fc_backward_weight             5.15008e-07\t\n",
      "2017-01-12 19:30:25,730 Node[0] Batch:     469 fc_backward_weight             5.52044e-07\t\n",
      "2017-01-12 19:30:25,730 Node[0] Batch:     469 fc_backward_weight             4.48535e-07\t\n",
      "2017-01-12 19:30:25,732 Node[0] Batch:     469 fc_backward_weight             5.8659e-07\t\n",
      "2017-01-12 19:30:31,356 Node[0] Epoch[1] Resetting Data Iterator\n",
      "2017-01-12 19:30:31,357 Node[0] Epoch[1] Time cost=8.140\n",
      "2017-01-12 19:30:31,868 Node[0] Epoch[1] Validation-accuracy=0.105769\n",
      "2017-01-12 19:30:31,869 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.504507\n",
      "2017-01-12 19:30:34,348 Node[0] Batch:     937 fc_backward_weight             5.96259e-07\t\n",
      "2017-01-12 19:30:34,349 Node[0] Batch:     937 fc_backward_weight             5.97974e-07\t\n",
      "2017-01-12 19:30:34,350 Node[0] Batch:     937 fc_backward_weight             4.51892e-07\t\n",
      "2017-01-12 19:30:34,351 Node[0] Batch:     937 fc_backward_weight             6.5213e-07\t\n",
      "2017-01-12 19:30:39,779 Node[0] Epoch[2] Resetting Data Iterator\n",
      "2017-01-12 19:30:39,780 Node[0] Epoch[2] Time cost=7.910\n",
      "2017-01-12 19:30:40,325 Node[0] Epoch[2] Validation-accuracy=0.105769\n",
      "2017-01-12 19:30:40,327 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.510216\n",
      "2017-01-12 19:30:42,989 Node[0] Batch:    1405 fc_backward_weight             6.52871e-07\t\n",
      "2017-01-12 19:30:42,989 Node[0] Batch:    1405 fc_backward_weight             6.20821e-07\t\n",
      "2017-01-12 19:30:42,990 Node[0] Batch:    1405 fc_backward_weight             4.46476e-07\t\n",
      "2017-01-12 19:30:42,992 Node[0] Batch:    1405 fc_backward_weight             7.53641e-07\t\n",
      "2017-01-12 19:30:48,245 Node[0] Epoch[3] Resetting Data Iterator\n",
      "2017-01-12 19:30:48,246 Node[0] Epoch[3] Time cost=7.917\n",
      "2017-01-12 19:30:48,555 Node[0] Epoch[3] Validation-accuracy=0.105769\n",
      "2017-01-12 19:30:48,556 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.510216\n",
      "2017-01-12 19:30:51,142 Node[0] Batch:    1873 fc_backward_weight             6.63064e-07\t\n",
      "2017-01-12 19:30:51,143 Node[0] Batch:    1873 fc_backward_weight             6.33577e-07\t\n",
      "2017-01-12 19:30:51,145 Node[0] Batch:    1873 fc_backward_weight             4.2922e-07\t\n",
      "2017-01-12 19:30:51,147 Node[0] Batch:    1873 fc_backward_weight             8.31741e-07\t\n",
      "2017-01-12 19:30:56,116 Node[0] Epoch[4] Resetting Data Iterator\n",
      "2017-01-12 19:30:56,117 Node[0] Epoch[4] Time cost=7.559\n",
      "2017-01-12 19:30:56,430 Node[0] Epoch[4] Validation-accuracy=0.103666\n",
      "2017-01-12 19:30:56,431 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:30:59,020 Node[0] Batch:    2341 fc_backward_weight             6.47525e-07\t\n",
      "2017-01-12 19:30:59,021 Node[0] Batch:    2341 fc_backward_weight             6.37593e-07\t\n",
      "2017-01-12 19:30:59,021 Node[0] Batch:    2341 fc_backward_weight             4.12299e-07\t\n",
      "2017-01-12 19:30:59,022 Node[0] Batch:    2341 fc_backward_weight             8.71203e-07\t\n",
      "2017-01-12 19:31:04,117 Node[0] Epoch[5] Resetting Data Iterator\n",
      "2017-01-12 19:31:04,118 Node[0] Epoch[5] Time cost=7.686\n",
      "2017-01-12 19:31:04,505 Node[0] Epoch[5] Validation-accuracy=0.103666\n",
      "2017-01-12 19:31:04,505 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:31:07,535 Node[0] Batch:    2809 fc_backward_weight             6.23424e-07\t\n",
      "2017-01-12 19:31:07,536 Node[0] Batch:    2809 fc_backward_weight             6.33117e-07\t\n",
      "2017-01-12 19:31:07,537 Node[0] Batch:    2809 fc_backward_weight             3.99334e-07\t\n",
      "2017-01-12 19:31:07,539 Node[0] Batch:    2809 fc_backward_weight             8.78155e-07\t\n",
      "2017-01-12 19:31:13,145 Node[0] Epoch[6] Resetting Data Iterator\n",
      "2017-01-12 19:31:13,149 Node[0] Epoch[6] Time cost=8.642\n",
      "2017-01-12 19:31:13,733 Node[0] Epoch[6] Validation-accuracy=0.107472\n",
      "2017-01-12 19:31:13,735 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:31:16,660 Node[0] Batch:    3277 fc_backward_weight             5.97921e-07\t\n",
      "2017-01-12 19:31:16,661 Node[0] Batch:    3277 fc_backward_weight             6.22105e-07\t\n",
      "2017-01-12 19:31:16,662 Node[0] Batch:    3277 fc_backward_weight             3.89208e-07\t\n",
      "2017-01-12 19:31:16,663 Node[0] Batch:    3277 fc_backward_weight             8.6379e-07\t\n",
      "2017-01-12 19:31:21,486 Node[0] Epoch[7] Resetting Data Iterator\n",
      "2017-01-12 19:31:21,487 Node[0] Epoch[7] Time cost=7.742\n",
      "2017-01-12 19:31:21,781 Node[0] Epoch[7] Validation-accuracy=0.109776\n",
      "2017-01-12 19:31:21,782 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:31:24,268 Node[0] Batch:    3745 fc_backward_weight             5.73259e-07\t\n",
      "2017-01-12 19:31:24,270 Node[0] Batch:    3745 fc_backward_weight             6.06878e-07\t\n",
      "2017-01-12 19:31:24,270 Node[0] Batch:    3745 fc_backward_weight             3.80379e-07\t\n",
      "2017-01-12 19:31:24,271 Node[0] Batch:    3745 fc_backward_weight             8.37382e-07\t\n",
      "2017-01-12 19:31:29,187 Node[0] Epoch[8] Resetting Data Iterator\n",
      "2017-01-12 19:31:29,188 Node[0] Epoch[8] Time cost=7.405\n",
      "2017-01-12 19:31:29,508 Node[0] Epoch[8] Validation-accuracy=0.105970\n",
      "2017-01-12 19:31:29,509 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.512620\n",
      "2017-01-12 19:31:31,994 Node[0] Batch:    4213 fc_backward_weight             5.49988e-07\t\n",
      "2017-01-12 19:31:31,995 Node[0] Batch:    4213 fc_backward_weight             5.89305e-07\t\n",
      "2017-01-12 19:31:31,996 Node[0] Batch:    4213 fc_backward_weight             3.71941e-07\t\n",
      "2017-01-12 19:31:31,997 Node[0] Batch:    4213 fc_backward_weight             8.05085e-07\t\n",
      "2017-01-12 19:31:37,268 Node[0] Epoch[9] Resetting Data Iterator\n",
      "2017-01-12 19:31:37,270 Node[0] Epoch[9] Time cost=7.760\n",
      "2017-01-12 19:31:37,707 Node[0] Epoch[9] Validation-accuracy=0.105970\n",
      "2017-01-12 19:31:37,708 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.512620\n"
     ]
    }
   ],
   "source": [
    "# Uniform and sigmoid\n",
    "args = parse_args('uniform', 'uniform_sigmoid')\n",
    "data_shape = (784, )\n",
    "net = get_mlp(\"sigmoid\")\n",
    "\n",
    "# train\n",
    "fit(args, net, get_iterator(data_shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you've seen, the metrics of `fc_backward_weight` is so close to zero, and it didn't change a lot during batchs.\n",
    "\n",
    "```\n",
    "2017-01-07 15:44:38,845 Node[0] Batch:       1 fc_backward_weight             5.1907e-07\t\n",
    "2017-01-07 15:44:38,846 Node[0] Batch:       1 fc_backward_weight             4.2085e-07\t\n",
    "2017-01-07 15:44:38,847 Node[0] Batch:       1 fc_backward_weight             4.31894e-07\t\n",
    "2017-01-07 15:44:38,848 Node[0] Batch:       1 fc_backward_weight             5.80652e-07\n",
    "\n",
    "2017-01-07 15:45:50,199 Node[0] Batch:    4213 fc_backward_weight             5.49988e-07\t\n",
    "2017-01-07 15:45:50,200 Node[0] Batch:    4213 fc_backward_weight             5.89305e-07\t\n",
    "2017-01-07 15:45:50,201 Node[0] Batch:    4213 fc_backward_weight             3.71941e-07\t\n",
    "2017-01-07 15:45:50,202 Node[0] Batch:    4213 fc_backward_weight             8.05085e-07\n",
    "```\n",
    "\n",
    "You might wonder why we have 4 different `fc_backward_weight`, cause we use 4 cpus."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Uniform and ReLu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-01-12 19:31:37,722 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='uniform', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='uniform_relu', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)\n",
      "2017-01-12 19:31:40,900 Node[0] \u001b[91m[Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n",
      "2017-01-12 19:31:40,906 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]\n",
      "2017-01-12 19:31:43,382 Node[0] Batch:       1 fc_backward_weight             0.000267409\t\n",
      "2017-01-12 19:31:43,383 Node[0] Batch:       1 fc_backward_weight             0.00031988\t\n",
      "2017-01-12 19:31:43,384 Node[0] Batch:       1 fc_backward_weight             0.000306785\t\n",
      "2017-01-12 19:31:43,385 Node[0] Batch:       1 fc_backward_weight             0.000347533\t\n",
      "2017-01-12 19:31:48,518 Node[0] Epoch[0] Resetting Data Iterator\n",
      "2017-01-12 19:31:48,519 Node[0] Epoch[0] Time cost=7.595\n",
      "2017-01-12 19:31:48,821 Node[0] Epoch[0] Validation-accuracy=0.694912\n",
      "2017-01-12 19:31:48,822 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.976362\n",
      "2017-01-12 19:31:51,256 Node[0] Batch:     469 fc_backward_weight             0.0527437\t\n",
      "2017-01-12 19:31:51,257 Node[0] Batch:     469 fc_backward_weight             0.0421219\t\n",
      "2017-01-12 19:31:51,259 Node[0] Batch:     469 fc_backward_weight             0.0495309\t\n",
      "2017-01-12 19:31:51,260 Node[0] Batch:     469 fc_backward_weight             0.0421051\t\n",
      "2017-01-12 19:31:56,442 Node[0] Epoch[1] Resetting Data Iterator\n",
      "2017-01-12 19:31:56,443 Node[0] Epoch[1] Time cost=7.619\n",
      "2017-01-12 19:31:56,723 Node[0] Epoch[1] Validation-accuracy=0.907652\n",
      "2017-01-12 19:31:56,724 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.986679\n",
      "2017-01-12 19:31:59,184 Node[0] Batch:     937 fc_backward_weight             0.0285753\t\n",
      "2017-01-12 19:31:59,185 Node[0] Batch:     937 fc_backward_weight             0.0520748\t\n",
      "2017-01-12 19:31:59,186 Node[0] Batch:     937 fc_backward_weight             0.0807526\t\n",
      "2017-01-12 19:31:59,187 Node[0] Batch:     937 fc_backward_weight             0.0502396\t\n",
      "2017-01-12 19:32:04,648 Node[0] Epoch[2] Resetting Data Iterator\n",
      "2017-01-12 19:32:04,649 Node[0] Epoch[2] Time cost=7.924\n",
      "2017-01-12 19:32:04,923 Node[0] Epoch[2] Validation-accuracy=0.921675\n",
      "2017-01-12 19:32:04,923 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.987380\n",
      "2017-01-12 19:32:07,411 Node[0] Batch:    1405 fc_backward_weight             0.0596137\t\n",
      "2017-01-12 19:32:07,412 Node[0] Batch:    1405 fc_backward_weight             0.145902\t\n",
      "2017-01-12 19:32:07,413 Node[0] Batch:    1405 fc_backward_weight             0.0783883\t\n",
      "2017-01-12 19:32:07,414 Node[0] Batch:    1405 fc_backward_weight             0.0810687\t\n",
      "2017-01-12 19:32:13,291 Node[0] Epoch[3] Resetting Data Iterator\n",
      "2017-01-12 19:32:13,292 Node[0] Epoch[3] Time cost=8.368\n",
      "2017-01-12 19:32:13,621 Node[0] Epoch[3] Validation-accuracy=0.947516\n",
      "2017-01-12 19:32:13,623 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.990084\n",
      "2017-01-12 19:32:16,028 Node[0] Batch:    1873 fc_backward_weight             0.113804\t\n",
      "2017-01-12 19:32:16,029 Node[0] Batch:    1873 fc_backward_weight             0.0355092\t\n",
      "2017-01-12 19:32:16,030 Node[0] Batch:    1873 fc_backward_weight             0.0510211\t\n",
      "2017-01-12 19:32:16,031 Node[0] Batch:    1873 fc_backward_weight             0.0461469\t\n",
      "2017-01-12 19:32:20,539 Node[0] Epoch[4] Resetting Data Iterator\n",
      "2017-01-12 19:32:20,541 Node[0] Epoch[4] Time cost=6.917\n",
      "2017-01-12 19:32:20,823 Node[0] Epoch[4] Validation-accuracy=0.949319\n",
      "2017-01-12 19:32:20,823 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.991587\n",
      "2017-01-12 19:32:23,312 Node[0] Batch:    2341 fc_backward_weight             0.0304884\t\n",
      "2017-01-12 19:32:23,313 Node[0] Batch:    2341 fc_backward_weight             0.0153732\t\n",
      "2017-01-12 19:32:23,314 Node[0] Batch:    2341 fc_backward_weight             0.0638052\t\n",
      "2017-01-12 19:32:23,315 Node[0] Batch:    2341 fc_backward_weight             0.0358958\t\n",
      "2017-01-12 19:32:27,721 Node[0] Epoch[5] Resetting Data Iterator\n",
      "2017-01-12 19:32:27,722 Node[0] Epoch[5] Time cost=6.897\n",
      "2017-01-12 19:32:28,116 Node[0] Epoch[5] Validation-accuracy=0.952224\n",
      "2017-01-12 19:32:28,117 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.991687\n",
      "2017-01-12 19:32:30,555 Node[0] Batch:    2809 fc_backward_weight             0.180743\t\n",
      "2017-01-12 19:32:30,556 Node[0] Batch:    2809 fc_backward_weight             0.0453026\t\n",
      "2017-01-12 19:32:30,558 Node[0] Batch:    2809 fc_backward_weight             0.0212601\t\n",
      "2017-01-12 19:32:30,558 Node[0] Batch:    2809 fc_backward_weight             0.0950233\t\n",
      "2017-01-12 19:32:36,190 Node[0] Epoch[6] Resetting Data Iterator\n",
      "2017-01-12 19:32:36,191 Node[0] Epoch[6] Time cost=8.074\n",
      "2017-01-12 19:32:36,548 Node[0] Epoch[6] Validation-accuracy=0.949219\n",
      "2017-01-12 19:32:36,552 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.992889\n",
      "2017-01-12 19:32:39,129 Node[0] Batch:    3277 fc_backward_weight             0.0977342\t\n",
      "2017-01-12 19:32:39,130 Node[0] Batch:    3277 fc_backward_weight             0.0354421\t\n",
      "2017-01-12 19:32:39,131 Node[0] Batch:    3277 fc_backward_weight             0.00394049\t\n",
      "2017-01-12 19:32:39,132 Node[0] Batch:    3277 fc_backward_weight             0.0402826\t\n",
      "2017-01-12 19:32:44,758 Node[0] Epoch[7] Resetting Data Iterator\n",
      "2017-01-12 19:32:44,759 Node[0] Epoch[7] Time cost=8.206\n",
      "2017-01-12 19:32:45,051 Node[0] Epoch[7] Validation-accuracy=0.956130\n",
      "2017-01-12 19:32:45,052 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.993389\n",
      "2017-01-12 19:32:47,585 Node[0] Batch:    3745 fc_backward_weight             0.012503\t\n",
      "2017-01-12 19:32:47,586 Node[0] Batch:    3745 fc_backward_weight             0.064014\t\n",
      "2017-01-12 19:32:47,587 Node[0] Batch:    3745 fc_backward_weight             0.0158367\t\n",
      "2017-01-12 19:32:47,588 Node[0] Batch:    3745 fc_backward_weight             0.00945755\t\n",
      "2017-01-12 19:32:53,593 Node[0] Epoch[8] Resetting Data Iterator\n",
      "2017-01-12 19:32:53,594 Node[0] Epoch[8] Time cost=8.541\n",
      "2017-01-12 19:32:54,017 Node[0] Epoch[8] Validation-accuracy=0.957031\n",
      "2017-01-12 19:32:54,018 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.992788\n",
      "2017-01-12 19:32:56,820 Node[0] Batch:    4213 fc_backward_weight             0.0226081\t\n",
      "2017-01-12 19:32:56,821 Node[0] Batch:    4213 fc_backward_weight             0.0039793\t\n",
      "2017-01-12 19:32:56,822 Node[0] Batch:    4213 fc_backward_weight             0.0306151\t\n",
      "2017-01-12 19:32:56,823 Node[0] Batch:    4213 fc_backward_weight             0.00818676\t\n",
      "2017-01-12 19:33:02,386 Node[0] Epoch[9] Resetting Data Iterator\n",
      "2017-01-12 19:33:02,387 Node[0] Epoch[9] Time cost=8.368\n",
      "2017-01-12 19:33:02,666 Node[0] Epoch[9] Validation-accuracy=0.959736\n",
      "2017-01-12 19:33:02,667 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.991987\n"
     ]
    }
   ],
   "source": [
    "# Uniform and ReLU\n",
    "args = parse_args('uniform', 'uniform_relu')\n",
    "data_shape = (784, )\n",
    "net = get_mlp(\"relu\")\n",
    "\n",
    "# train\n",
    "fit(args, net, get_iterator(data_shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Even with a \"poor\" initialization, the model can still converge quickly with a proper activation function, and the gradient magnitudes differ significantly. Compare the first and last logged batches:\n",
    "\n",
    "```\n",
    "2017-01-07 15:54:12,286 Node[0] Batch:       1 fc_backward_weight             0.000267409\t\n",
    "2017-01-07 15:54:12,287 Node[0] Batch:       1 fc_backward_weight             0.00031988\t\n",
    "2017-01-07 15:54:12,288 Node[0] Batch:       1 fc_backward_weight             0.000306785\t\n",
    "2017-01-07 15:54:12,289 Node[0] Batch:       1 fc_backward_weight             0.000347533\n",
    "\n",
    "2017-01-07 15:55:25,936 Node[0] Batch:    4213 fc_backward_weight             0.0226081\t\n",
    "2017-01-07 15:55:25,937 Node[0] Batch:    4213 fc_backward_weight             0.0039793\t\n",
    "2017-01-07 15:55:25,937 Node[0] Batch:    4213 fc_backward_weight             0.0306151\t\n",
    "2017-01-07 15:55:25,938 Node[0] Batch:    4213 fc_backward_weight             0.00818676\n",
    "```"
   ]
  },
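  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough way to see why the activation function matters so much, offered as a sketch rather than an exact analysis of this network: backpropagation multiplies the gradient by the derivative of the activation at every layer. The symbols $a_l$, $z_l$, $W_l$, $b_l$ and $f$ below are notation introduced here for illustration and do not appear in the code.\n",
    "\n",
    "$$z_l = W_l a_{l-1} + b_l, \\qquad a_l = f(z_l), \\qquad \\frac{\\partial \\mathcal{L}}{\\partial a_{l-1}} = W_l^\\top \\left( f'(z_l) \\odot \\frac{\\partial \\mathcal{L}}{\\partial a_l} \\right)$$\n",
    "\n",
    "$$\\sigma'(z) = \\sigma(z)\\,(1 - \\sigma(z)) \\le \\frac{1}{4}, \\qquad \\mathrm{ReLU}'(z) = 1 \\ \\text{for } z > 0 \\ (0 \\text{ otherwise})$$\n",
    "\n",
    "So each sigmoid layer scales every component of the backpropagated signal by at most $1/4$, and $k$ stacked sigmoid layers contribute a factor of at most $4^{-k}$ before the weights are taken into account, while ReLU passes the signal through unchanged for active units."
   ]
  },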
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Xavier and Sigmoid  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2017-01-12 19:33:02,682 Node[0] start with arguments Namespace(batch_size=128, data_dir='mnist/', gpus=None, init='xavier', kv_store='local', load_epoch=None, lr=0.1, lr_factor=1, lr_factor_epoch=1, model_prefix=None, name='xavier_sigmoid', network='mlp', num_epochs=10, num_examples=60000, save_model_prefix=None)\n",
      "2017-01-12 19:33:05,863 Node[0] \u001b[91m[Deprecation Warning] mxnet.model.FeedForward has been deprecated. Please use mxnet.mod.Module instead.\u001b[0m\n",
      "2017-01-12 19:33:05,871 Node[0] Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]\n",
      "2017-01-12 19:33:08,355 Node[0] Batch:       1 fc_backward_weight             9.27798e-06\t\n",
      "2017-01-12 19:33:08,356 Node[0] Batch:       1 fc_backward_weight             8.58008e-06\t\n",
      "2017-01-12 19:33:08,358 Node[0] Batch:       1 fc_backward_weight             8.96261e-06\t\n",
      "2017-01-12 19:33:08,359 Node[0] Batch:       1 fc_backward_weight             7.33611e-06\t\n",
      "2017-01-12 19:33:13,214 Node[0] Epoch[0] Resetting Data Iterator\n",
      "2017-01-12 19:33:13,215 Node[0] Epoch[0] Time cost=7.320\n",
      "2017-01-12 19:33:13,516 Node[0] Epoch[0] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:13,517 Node[0] Epoch[0] Validation-top_k_accuracy_5=0.509115\n",
      "2017-01-12 19:33:15,898 Node[0] Batch:     469 fc_backward_weight             6.76125e-06\t\n",
      "2017-01-12 19:33:15,899 Node[0] Batch:     469 fc_backward_weight             6.54805e-06\t\n",
      "2017-01-12 19:33:15,900 Node[0] Batch:     469 fc_backward_weight             6.80302e-06\t\n",
      "2017-01-12 19:33:15,901 Node[0] Batch:     469 fc_backward_weight             7.39115e-06\t\n",
      "2017-01-12 19:33:21,153 Node[0] Epoch[1] Resetting Data Iterator\n",
      "2017-01-12 19:33:21,154 Node[0] Epoch[1] Time cost=7.637\n",
      "2017-01-12 19:33:21,438 Node[0] Epoch[1] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:21,439 Node[0] Epoch[1] Validation-top_k_accuracy_5=0.504507\n",
      "2017-01-12 19:33:23,894 Node[0] Batch:     937 fc_backward_weight             5.83071e-06\t\n",
      "2017-01-12 19:33:23,895 Node[0] Batch:     937 fc_backward_weight             5.59626e-06\t\n",
      "2017-01-12 19:33:23,895 Node[0] Batch:     937 fc_backward_weight             5.776e-06\t\n",
      "2017-01-12 19:33:23,896 Node[0] Batch:     937 fc_backward_weight             6.28738e-06\t\n",
      "2017-01-12 19:33:28,578 Node[0] Epoch[2] Resetting Data Iterator\n",
      "2017-01-12 19:33:28,580 Node[0] Epoch[2] Time cost=7.139\n",
      "2017-01-12 19:33:28,870 Node[0] Epoch[2] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:28,871 Node[0] Epoch[2] Validation-top_k_accuracy_5=0.510216\n",
      "2017-01-12 19:33:31,294 Node[0] Batch:    1405 fc_backward_weight             4.951e-06\t\n",
      "2017-01-12 19:33:31,296 Node[0] Batch:    1405 fc_backward_weight             4.72836e-06\t\n",
      "2017-01-12 19:33:31,299 Node[0] Batch:    1405 fc_backward_weight             4.8514e-06\t\n",
      "2017-01-12 19:33:31,302 Node[0] Batch:    1405 fc_backward_weight             5.26915e-06\t\n",
      "2017-01-12 19:33:36,266 Node[0] Epoch[3] Resetting Data Iterator\n",
      "2017-01-12 19:33:36,267 Node[0] Epoch[3] Time cost=7.395\n",
      "2017-01-12 19:33:36,576 Node[0] Epoch[3] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:36,577 Node[0] Epoch[3] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:33:38,997 Node[0] Batch:    1873 fc_backward_weight             4.22193e-06\t\n",
      "2017-01-12 19:33:38,998 Node[0] Batch:    1873 fc_backward_weight             4.03044e-06\t\n",
      "2017-01-12 19:33:38,999 Node[0] Batch:    1873 fc_backward_weight             4.11877e-06\t\n",
      "2017-01-12 19:33:39,000 Node[0] Batch:    1873 fc_backward_weight             4.45402e-06\t\n",
      "2017-01-12 19:33:44,271 Node[0] Epoch[4] Resetting Data Iterator\n",
      "2017-01-12 19:33:44,272 Node[0] Epoch[4] Time cost=7.695\n",
      "2017-01-12 19:33:44,567 Node[0] Epoch[4] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:44,568 Node[0] Epoch[4] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:33:47,092 Node[0] Batch:    2341 fc_backward_weight             3.64564e-06\t\n",
      "2017-01-12 19:33:47,094 Node[0] Batch:    2341 fc_backward_weight             3.48901e-06\t\n",
      "2017-01-12 19:33:47,094 Node[0] Batch:    2341 fc_backward_weight             3.55765e-06\t\n",
      "2017-01-12 19:33:47,095 Node[0] Batch:    2341 fc_backward_weight             3.82692e-06\t\n",
      "2017-01-12 19:33:52,308 Node[0] Epoch[5] Resetting Data Iterator\n",
      "2017-01-12 19:33:52,309 Node[0] Epoch[5] Time cost=7.740\n",
      "2017-01-12 19:33:52,674 Node[0] Epoch[5] Validation-accuracy=0.105769\n",
      "2017-01-12 19:33:52,675 Node[0] Epoch[5] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:33:55,045 Node[0] Batch:    2809 fc_backward_weight             3.19336e-06\t\n",
      "2017-01-12 19:33:55,046 Node[0] Batch:    2809 fc_backward_weight             3.06777e-06\t\n",
      "2017-01-12 19:33:55,046 Node[0] Batch:    2809 fc_backward_weight             3.12543e-06\t\n",
      "2017-01-12 19:33:55,047 Node[0] Batch:    2809 fc_backward_weight             3.34344e-06\t\n",
      "2017-01-12 19:34:02,004 Node[0] Epoch[6] Resetting Data Iterator\n",
      "2017-01-12 19:34:02,005 Node[0] Epoch[6] Time cost=9.328\n",
      "2017-01-12 19:34:02,295 Node[0] Epoch[6] Validation-accuracy=0.107472\n",
      "2017-01-12 19:34:02,296 Node[0] Epoch[6] Validation-top_k_accuracy_5=0.509014\n",
      "2017-01-12 19:34:04,830 Node[0] Batch:    3277 fc_backward_weight             2.83478e-06\t\n",
      "2017-01-12 19:34:04,831 Node[0] Batch:    3277 fc_backward_weight             2.73443e-06\t\n",
      "2017-01-12 19:34:04,832 Node[0] Batch:    3277 fc_backward_weight             2.78607e-06\t\n",
      "2017-01-12 19:34:04,833 Node[0] Batch:    3277 fc_backward_weight             2.9644e-06\t\n",
      "2017-01-12 19:34:10,903 Node[0] Epoch[7] Resetting Data Iterator\n",
      "2017-01-12 19:34:10,904 Node[0] Epoch[7] Time cost=8.607\n",
      "2017-01-12 19:34:11,209 Node[0] Epoch[7] Validation-accuracy=0.105970\n",
      "2017-01-12 19:34:11,210 Node[0] Epoch[7] Validation-top_k_accuracy_5=0.512620\n",
      "2017-01-12 19:34:13,543 Node[0] Batch:    3745 fc_backward_weight             2.54587e-06\t\n",
      "2017-01-12 19:34:13,544 Node[0] Batch:    3745 fc_backward_weight             2.46527e-06\t\n",
      "2017-01-12 19:34:13,545 Node[0] Batch:    3745 fc_backward_weight             2.51372e-06\t\n",
      "2017-01-12 19:34:13,546 Node[0] Batch:    3745 fc_backward_weight             2.66109e-06\t\n",
      "2017-01-12 19:34:17,928 Node[0] Epoch[8] Resetting Data Iterator\n",
      "2017-01-12 19:34:17,928 Node[0] Epoch[8] Time cost=6.718\n",
      "2017-01-12 19:34:18,225 Node[0] Epoch[8] Validation-accuracy=0.105970\n",
      "2017-01-12 19:34:18,226 Node[0] Epoch[8] Validation-top_k_accuracy_5=0.512620\n",
      "2017-01-12 19:34:20,790 Node[0] Batch:    4213 fc_backward_weight             2.30903e-06\t\n",
      "2017-01-12 19:34:20,791 Node[0] Batch:    4213 fc_backward_weight             2.24373e-06\t\n",
      "2017-01-12 19:34:20,792 Node[0] Batch:    4213 fc_backward_weight             2.29058e-06\t\n",
      "2017-01-12 19:34:20,793 Node[0] Batch:    4213 fc_backward_weight             2.41351e-06\t\n",
      "2017-01-12 19:34:26,309 Node[0] Epoch[9] Resetting Data Iterator\n",
      "2017-01-12 19:34:26,310 Node[0] Epoch[9] Time cost=8.083\n",
      "2017-01-12 19:34:26,629 Node[0] Epoch[9] Validation-accuracy=0.105970\n",
      "2017-01-12 19:34:26,630 Node[0] Epoch[9] Validation-top_k_accuracy_5=0.512620\n"
     ]
    }
   ],
   "source": [
    "# Xavier and sigmoid\n",
    "args = parse_args('xavier', 'xavier_sigmoid')\n",
    "data_shape = (784, )\n",
    "net = get_mlp(\"sigmoid\")\n",
    "\n",
    "# train\n",
    "fit(args, net, get_iterator(data_shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization\n",
    "\n",
    "Now start TensorBoard and point it at the log directory:\n",
    "\n",
    "```bash\n",
    "tensorboard --logdir=logs/\n",
    "```\n",
    "\n",
    "![Dashboard](https://github.com/zihaolucky/tensorboard/raw/data/docs/tutorial/mnist/pic1.png)\n",
    "\n",
    "![dist](https://github.com/zihaolucky/tensorboard/raw/data/docs/tutorial/mnist/pic2.png)\n",
    "\n",
    "![hist](https://github.com/zihaolucky/tensorboard/raw/data/docs/tutorial/mnist/pic3.png)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "You might find these materials useful:\n",
    "\n",
    "[1] [Rohan #4: The vanishing gradient problem – A Year of Artificial Intelligence](https://ayearofai.com/rohan-4-the-vanishing-gradient-problem-ec68f76ffb9b#.bojpejg3o)    \n",
    "[2] [On the difficulty of training recurrent and deep neural networks - YouTube](https://www.youtube.com/watch?v=A7poQbTrhxc)    \n",
    "[3] [What is the vanishing gradient problem? - Quora](https://www.quora.com/What-is-the-vanishing-gradient-problem)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
